{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"></ul></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected cross entropy loss if the model:\n",
      "- learns neither dependency: 0.661563238158\n",
      "- learns first dependency: 0.519166699707\n",
      "- learns both dependency: 0.454454367449\n"
     ]
    }
   ],
   "source": [
    "# Theoretical cross-entropy floors for the data produced by gen_data() (defined\n",
    "# further down): P(Y=1) starts at 0.5, gains 0.5 when X[i-3] == 1 and loses 0.25\n",
    "# when X[i-8] == 1, with both inputs being fair coin flips.\n",
    "\n",
    "# Knowing neither input: P(Y=1) averaged over the four equally likely cases is 0.625.\n",
    "loss_neither = -(0.625 * np.log(0.625) + 0.375 * np.log(0.375))\n",
    "\n",
    "# Knowing only X[i-3]: P(Y=1) is 0.875 when it is 1 and 0.375 when it is 0.\n",
    "loss_first = (-0.5 * (0.875 * np.log(0.875) + 0.125 * np.log(0.125))\n",
    "              -0.5 * (0.625 * np.log(0.625) + 0.375 * np.log(0.375)))\n",
    "\n",
    "# Knowing both inputs: P(Y=1) is 0.75 or 0.25 in half of the cases (same entropy),\n",
    "# 0.5 in a quarter of them, and 1.0 in the last quarter (zero entropy, the `(0)` term).\n",
    "loss_both = (-0.5 * (0.75 * np.log(0.75) + 0.25 * np.log(0.25))\n",
    "             -0.25 * (2 * 0.5 * np.log(0.5) - 0.25 * (0)))\n",
    "\n",
    "print(\"Expected cross entropy loss if the model:\")\n",
    "print(\"- learns neither dependency:\", loss_neither)\n",
    "print(\"- learns first dependency:\", loss_first)\n",
    "print(\"- learns both dependency:\", loss_both)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']=''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by the data pipeline and the TensorFlow graph below.\n",
    "# Length of each training window: the placeholders are [batch_size, num_steps]\n",
    "# and the RNN is unrolled once per step.\n",
    "num_steps = 10\n",
    "# Number of sequences processed in parallel (rows of every mini-batch).\n",
    "batch_size = 200\n",
    "# Depth of the one-hot input encoding and width of the softmax output.\n",
    "num_classes = 2\n",
    "# Width of the RNN hidden state.\n",
    "state_size = 4\n",
    "# Learning rate passed to tf.train.AdamOptimizer.\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(size = 1000000):\n",
    "    \"\"\"Generate a random binary input sequence X and a dependent binary target Y.\n",
    "\n",
    "    X[i] is drawn uniformly from {0, 1}. P(Y[i] = 1) starts at 0.5, increases by\n",
    "    0.5 when X[i-3] == 1 and decreases by 0.25 when X[i-8] == 1.\n",
    "\n",
    "    The look-back wraps around: for i < 3 (resp. i < 8) the index i-3 (resp. i-8)\n",
    "    is negative and reads from the end of X, exactly like the element-wise loop\n",
    "    this replaces. The vectorised form consumes the global NumPy random stream in\n",
    "    the same order as that loop, but avoids one Python iteration per element and\n",
    "    also works for size < 8.\n",
    "\n",
    "    Args:\n",
    "        size: length of both sequences.\n",
    "\n",
    "    Returns:\n",
    "        (X, Y): two integer arrays of shape (size,).\n",
    "    \"\"\"\n",
    "    X = np.array(np.random.choice(2, size=(size,)))\n",
    "    # np.roll(X, k)[i] == X[i - k] with wrap-around, i.e. Python negative indexing.\n",
    "    threshold = (0.5\n",
    "                 + 0.5 * (np.roll(X, 3) == 1)\n",
    "                 - 0.25 * (np.roll(X, 8) == 1))\n",
    "    # One uniform draw per element; Y is 0 only when the draw exceeds the threshold.\n",
    "    draws = np.random.rand(size)\n",
    "    Y = (draws <= threshold).astype(int)\n",
    "    return X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_batch(raw_date, batch_size, num_steps):\n",
    "    \"\"\"Yield (x, y) mini-batches cut from a pair of equally long sequences.\n",
    "\n",
    "    The sequences are split into batch_size contiguous rows (any remainder at the\n",
    "    end is dropped), and the rows are then walked left to right in windows of\n",
    "    num_steps columns. The number of windows is printed before the first one is\n",
    "    yielded.\n",
    "\n",
    "    Args:\n",
    "        raw_date: (raw_x, raw_y) pair of sequences; values are stored as int32.\n",
    "        batch_size: number of rows in every yielded array.\n",
    "        num_steps: number of columns in every yielded array.\n",
    "\n",
    "    Yields:\n",
    "        (x, y): two int32 arrays of shape [batch_size, num_steps].\n",
    "    \"\"\"\n",
    "    inputs, targets = raw_date\n",
    "    row_len = len(inputs) // batch_size\n",
    "    used = batch_size * row_len\n",
    "\n",
    "    # Row r holds the contiguous slice [r * row_len, (r + 1) * row_len).\n",
    "    data_x = np.array(inputs[:used], dtype=np.int32).reshape(batch_size, row_len)\n",
    "    data_y = np.array(targets[:used], dtype=np.int32).reshape(batch_size, row_len)\n",
    "\n",
    "    epoch_size = row_len // num_steps\n",
    "    print(\"epoch_size %d\" % epoch_size)\n",
    "\n",
    "    # Consecutive, non-overlapping windows of num_steps columns.\n",
    "    for start in range(0, epoch_size * num_steps, num_steps):\n",
    "        stop = start + num_steps\n",
    "        yield (data_x[:, start:stop], data_y[:, start:stop])\n",
    "        \n",
    "def gen_epochs(n, num_steps):\n",
    "    \"\"\"Yield n batch generators, each over a freshly generated data set.\n",
    "\n",
    "    Uses the notebook-level batch_size for every epoch.\n",
    "    \"\"\"\n",
    "    for _ in range(n):\n",
    "        epoch_data = gen_data()\n",
    "        yield gen_batch(epoch_data, batch_size, num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the default graph so the cells below build into a fresh one.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Integer inputs and labels: one row per sequence, one column per time step.\n",
    "x = tf.placeholder(tf.int32, [batch_size, num_steps], name='input_placeholder')\n",
    "y = tf.placeholder(tf.int32, [batch_size, num_steps], name='labels_placeholder')\n",
    "# Zero initial hidden state; train_network feeds the previous window's final state in its place.\n",
    "init_state = tf.zeros([batch_size, state_size])\n",
    "\n",
    "x_one_hot = tf.one_hot(x, num_classes)  # [batch_size, num_steps] -->\n",
    "                                        # [batch_size, num_steps, num_classes]\n",
    "rnn_inputs = tf.unstack(x_one_hot, axis = 1) # for convenience: a list of num_steps tensors, each (batch_size, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the RNN cell parameters once; rnn_cell() fetches these same variables with reuse=True.\n",
    "# W maps the concatenated [input, state] vector (num_classes + state_size wide) to the next state.\n",
    "with tf.variable_scope('rnn_cell'):\n",
    "    W = tf.get_variable('W', [num_classes + state_size, state_size])\n",
    "    b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n",
    "\n",
    "def rnn_cell(rnn_input, state):\n",
    "    \"\"\"Apply one vanilla RNN step: tanh(concat(rnn_input, state) x W + b).\n",
    "\n",
    "    The weights are looked up by name in the 'rnn_cell' scope with reuse=True,\n",
    "    so every call shares the same W and b.\n",
    "\n",
    "    Args:\n",
    "        rnn_input: input tensor for one time step; concatenated with state on axis 1.\n",
    "        state: previous hidden state.\n",
    "\n",
    "    Returns:\n",
    "        The next hidden state tensor.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope('rnn_cell', reuse=True):\n",
    "        weights = tf.get_variable('W', [num_classes + state_size, state_size])\n",
    "        bias = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n",
    "    joined = tf.concat([rnn_input, state], 1)\n",
    "    return tf.tanh(tf.matmul(joined, weights) + bias)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unroll the RNN: apply rnn_cell once per time step, threading the state through.\n",
    "state = init_state # zeros by default since no earlier state exists; train_network feeds a carried-over state instead\n",
    "rnn_outputs = []\n",
    "for rnn_input in rnn_inputs:  # rnn_inputs holds num_steps tensors of shape (batch_size, num_classes)\n",
    "    state = rnn_cell(rnn_input, state)  # get the state \n",
    "    rnn_outputs.append(state)\n",
    "final_state = rnn_outputs[-1]   # after num_steps forward steps, keep the last state as final_state\n",
    "                                # so the next training step can carry the memory forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output layer: project every hidden state to num_classes logits with shared softmax weights.\n",
    "# NOTE: this rebinds the notebook-level names W and b; rnn_cell() is unaffected because it\n",
    "# looks its own variables up by name inside the 'rnn_cell' scope.\n",
    "with tf.variable_scope('softmax'):\n",
    "    W = tf.get_variable('W', [state_size, num_classes])\n",
    "    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))\n",
    "logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]\n",
    "predictions = [tf.nn.softmax(logit) for logit in logits]\n",
    "\n",
    "# Split the labels into num_steps tensors, one per time step, to pair with the logits.\n",
    "y_as_list = tf.unstack(y, num=num_steps, axis=1)\n",
    "\n",
    "# Per-step cross-entropy between the integer labels and the logits.\n",
    "losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logit) for \n",
    "         logit, label in zip(logits, y_as_list)]\n",
    "# Mean over all time steps and batch entries.\n",
    "total_loss = tf.reduce_mean(losses)\n",
    "train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "EPOCH 0\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.87347663343\n",
      "Average loss at step 200 for last 100 steps: 0.710601933599\n",
      "Average loss at step 300 for last 100 steps: 0.650836793184\n",
      "Average loss at step 400 for last 100 steps: 0.626289330721\n",
      "\n",
      "EPOCH 1\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.590729051232\n",
      "Average loss at step 200 for last 100 steps: 0.55818700254\n",
      "Average loss at step 300 for last 100 steps: 0.537134252191\n",
      "Average loss at step 400 for last 100 steps: 0.526792359352\n",
      "\n",
      "EPOCH 2\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.52745998472\n",
      "Average loss at step 200 for last 100 steps: 0.51907700181\n",
      "Average loss at step 300 for last 100 steps: 0.519309648275\n",
      "Average loss at step 400 for last 100 steps: 0.520301951468\n",
      "\n",
      "EPOCH 3\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.525068391562\n",
      "Average loss at step 200 for last 100 steps: 0.519198945165\n",
      "Average loss at step 300 for last 100 steps: 0.520553269088\n",
      "Average loss at step 400 for last 100 steps: 0.518290593028\n",
      "\n",
      "EPOCH 4\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.522011487186\n",
      "Average loss at step 200 for last 100 steps: 0.518377505541\n",
      "Average loss at step 300 for last 100 steps: 0.516962650716\n",
      "Average loss at step 400 for last 100 steps: 0.516596245766\n",
      "\n",
      "EPOCH 5\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.523461107612\n",
      "Average loss at step 200 for last 100 steps: 0.516289199293\n",
      "Average loss at step 300 for last 100 steps: 0.517943150699\n",
      "Average loss at step 400 for last 100 steps: 0.518050109446\n",
      "\n",
      "EPOCH 6\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.522919949889\n",
      "Average loss at step 200 for last 100 steps: 0.519882829189\n",
      "Average loss at step 300 for last 100 steps: 0.518844670951\n",
      "Average loss at step 400 for last 100 steps: 0.51680441618\n",
      "\n",
      "EPOCH 7\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.522926489115\n",
      "Average loss at step 200 for last 100 steps: 0.516543312073\n",
      "Average loss at step 300 for last 100 steps: 0.51757014364\n",
      "Average loss at step 400 for last 100 steps: 0.515813337564\n",
      "\n",
      "EPOCH 8\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.522675927281\n",
      "Average loss at step 200 for last 100 steps: 0.51491961658\n",
      "Average loss at step 300 for last 100 steps: 0.514735060334\n",
      "Average loss at step 400 for last 100 steps: 0.515167931616\n",
      "\n",
      "EPOCH 9\n",
      "epoch_size 500\n",
      "Average loss at step 100 for last 100 steps: 0.52209299773\n",
      "Average loss at step 200 for last 100 steps: 0.515804851055\n",
      "Average loss at step 300 for last 100 steps: 0.515071320236\n",
      "Average loss at step 400 for last 100 steps: 0.516001387239\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fecf848e390>]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3Xt0XOV57/HvoxmNJOtiXW1sy7bkW4wTwIDsAoZcIDiGJJCsXGpoWkhIOO2qSZq0PYXVnoSQ1XPSs5qTZPX4JCUpoblwC2kTQ0kJNCYkXC3AONjGxndJvkmyZcm6j+Y5f8yWGYRsje2Rtzzz+6w1a2a/s7fm8bb9m1fv3vvd5u6IiEhuyAu7ABEROXMU+iIiOUShLyKSQxT6IiI5RKEvIpJDFPoiIjlEoS8ikkMU+iIiOUShLyKSQ6JhFzBSdXW119XVhV2GiMhZ5aWXXmpz95qx1ptwoV9XV0djY2PYZYiInFXMbHc662l4R0Qkhyj0RURyiEJfRCSHKPRFRHKIQl9EJIco9EVEcohCX0Qkh2RN6Hf2DfLNJ7ayvqkj7FJERCasrAl9T8C3/+sNGncdCrsUEZEJK2tCv6woSn7EaDs6EHYpIiITVtaEvplRVVxA+9H+sEsREZmwsib0AapKYrR3q6cvInI8WRb66umLiJxIVoV+dUlMY/oiIieQZaFfQNvRftw97FJERCakrAr9quIY/fEE3QNDYZciIjIhZVfolxQA0NalcX0RkdFkVehXl8QAaO9W6IuIjCat0DezFWa2xcy2mdnto7w/y8zWmtkrZrbBzK4N2uvMrNfM1geP72b6D5Cqerinr4O5IiKjGvMeuWYWAVYDVwPNwDozW+Pum1JW+zvgIXf/jpktAh4D6oL3trv74syWPbqq4Z6+Ql9EZFTp9PSXAtvcfYe7DwAPANePWMeBsuD1ZGBv5kpMX2VxMvTbdK6+iMio0gn9GUBTynJz0JbqTuBTZtZMspd/W8p79cGwz2/M7IrRPsDMbjWzRjNrbG1tTb/6EQqiEcoKo7pAS0TkODJ1IPcG4F53rwWuBX5kZnnAPmCWu18IfAm4z8zKRm7s7ne7e4O7N9TU1JxWIdUlBbRpKgYRkVGlE/otwMyU5dqgLdUtwEMA7v4cUAhUu3u/u7cH7S8B24EFp1v0iVSVxNTTFxE5jnRCfx0w38zqzSwGrATWjFhnD3AVgJmdSzL0W82sJjgQjJnNAeYDOzJV/GiSV+Wqpy8iMpoxQ9/d48Aq4HFgM8mzdDaa2V1mdl2w2l8CnzOzV4H7gZs9ORfCu4ENZrYeeBj4U3cf17ucqKcvInJ8Y56yCeDuj5E8QJva9uWU15uAZaNs9zPgZ6dZ40mpKi7gcM8g8aEE0UhWXXsmInLasi4Vh6/KPdSjIR4RkZGyMPSH599R6IuIjJR1oT886Zrm3xERebssDH1NxSAicjxZF/rVxcOTrqmnLyIyUtaFfllRlPyI6Vx9EZFRZF3omxlVxbpBuojIaLIu9CG4QEvz74iIvE2Whr56+iIio8nK0K8uiWlMX0RkFFka+gW0He0nOf2PiIgMy8rQryqO0R9P0D0wFHYpIiITSlaG/vBUDBrXFxF5q6wM/eGrcnWBlojIW2Vl6B+bdE0Hc0VE3iIrQ1/z74iIjC47Q79YY/oiIqPJytCPRfMoK4xqTF9EZIS0Qt/MVpjZFjPbZma3j/L+LDNba2avmNkGM7s25b07gu22mNkHMln8iVSXFNCmqRhERN5izHvkmlkEWA1cDTQD68xsTXBf3GF/R/KG6d8xs0Uk76dbF7xeCbwTmA48aWYL3H3cT6DXDdJFRN4unZ7+UmCbu+9w9wHgAeD6Ees4UBa8ngzsDV5fDzzg7v3uvhPYFvy8cVddUqADuSIiI6QT+jOAppTl5qAt1Z3Ap8ysmWQv/7aT2HZcVJXENKYvIjJCpg7k3gDc6+61
wLXAj8ws7Z9tZreaWaOZNba2tmakoKriAg73DBIfSmTk54mIZIN0grkFmJmyXBu0pboFeAjA3Z8DCoHqNLfF3e929wZ3b6ipqUm/+hOoDs7VP9SjIR4RkWHphP46YL6Z1ZtZjOSB2TUj1tkDXAVgZueSDP3WYL2VZlZgZvXAfODFTBV/Iseuyu1S6IuIDBvz7B13j5vZKuBxIALc4+4bzewuoNHd1wB/CXzPzL5I8qDuzZ6c13ijmT0EbALiwJ+fiTN3IHkjFYD2bo3ri4gMGzP0Adz9MZIHaFPbvpzyehOw7Djb/j3w96dR4ynRVAwiIm+XlVfkQuqka+rpi4gMy9rQLyuMkh8xzbQpIpIia0PfzKgq1g3SRURSZW3oQzAVg+bfERE5JqtDPzkVg3r6IiLDsjr0k1MxqKcvIjIsq0O/uqSAtqP9JC8ZEBGRrA79quIY/fEE3QNn5HowEZEJL6tDf/hcfY3ri4gkZXXoD1+Vqwu0RESSsjr037wqVwdzRUQgy0Nf8++IiLxVdod+scb0RURSZXXox6J5lBVGNaYvIhLI6tCH4Fx9TcUgIgLkSOhreEdEJCnrQ7+qJKYDuSIigZwIfY3pi4gkZX/oFxdwuGeQ+FAi7FJEREKXVuib2Qoz22Jm28zs9lHe/6aZrQ8eW82sI+W9oZT31mSy+HRUlyZP2zzUoyEeEZExb4xuZhFgNXA10AysM7M1wc3QAXD3L6asfxtwYcqP6HX3xZkr+eRUF795gdaU0sKwyhARmRDS6ekvBba5+w53HwAeAK4/wfo3APdnorhMqNIN0kVEjkkn9GcATSnLzUHb25jZbKAe+HVKc6GZNZrZ82b2keNsd2uwTmNra2uapadHUzGIiLwp0wdyVwIPu3vqBPaz3b0BuBH4lpnNHbmRu9/t7g3u3lBTU5PRgqrV0xcROSad0G8BZqYs1wZto1nJiKEdd28JnncAT/HW8f5xV1YYJT9iukG6iAjphf46YL6Z1ZtZjGSwv+0sHDNbCFQAz6W0VZhZQfC6GlgGbBq57XgyM6qKC2jrUk9fRGTMs3fcPW5mq4DHgQhwj7tvNLO7gEZ3H/4CWAk84G+9Ie25wD+bWYLkF8zXU8/6OVOqSmLq6YuIkEboA7j7Y8BjI9q+PGL5zlG2exY47zTqywjNvyMikpT1V+TC8FQM6umLiORE6FeXFNB2tJ+3jjyJiOSeHAn9GP3xBN0DQ2OvLCKSxXIi9HXbRBGRpNwI/eCqXF2gJSK5LidC/82rcnUwV0RyW06FvubfEZFclxOhX3lsemUN74hIbsuJ0I9F8ygrjGpMX0RyXk6EPgTn6msqBhHJcTkV+hreEZFclzOhX1US04FcEcl5ORX6GtMXkVyXO6FfXMDhnkHiQ4mwSxERCU3OhH51afJc/UM9GuIRkdyVO6EfnKvf1qXQF5HclTOhX1ddDMDr+ztDrkREJDw5E/oLppZSWhhl3a5DYZciIhKatELfzFaY2RYz22Zmt4/y/jfNbH3w2GpmHSnv3WRmbwSPmzJZ/MmI5BkNsytYt+twWCWIiIRuzHvkmlkEWA1cDTQD68xsTeoNzt39iynr3wZcGLyuBL4CNAAOvBRsG0ryLqmvZO2WLRzqHjg2H4+ISC5Jp6e/FNjm7jvcfQB4ALj+BOvfANwfvP4A8IS7HwqC/glgxekUfDqW1FUCaIhHRHJWOqE/A2hKWW4O2t7GzGYD9cCvT3bbM+H82snEonk0KvRFJEdl+kDuSuBhdz+pm9Ga2a1m1mhmja2trRku6U0F0QiLa8t5UeP6IpKj0gn9FmBmynJt0Daalbw5tJP2tu5+t7s3uHtDTU1NGiWduoa6Cja2HKFnID6unyMiMhGlE/rrgPlmVm9mMZLBvmbkSma2EKgAnktpfhxYbmYVZlYBLA/aQrOkvpJ4wlm/p2PslUVEssyYoe/ucWAVybDeDDzk7hvN7C4zuy5l1ZXAA+7uKdseAr5G8otjHXBX0Baa
i2dXYAYvalxfRHLQmKdsArj7Y8BjI9q+PGL5zuNsew9wzynWl3FlhfksPKdMZ/CISE7KmStyUy2tq+CVPR0MasZNEckxORn6S+or6RkYYtNezcMjIrklN0NfF2mJSI7KydCfWlbIrMpJCn0RyTk5GfqQ7O037jpMyslGIiJZL4dDv4L27gG2t3aHXYqIyBmTu6FfnxzX1zw8IpJLcjb051QXU1Uc00VaIpJTcjb0zYyGugodzBWRnJKzoQ/Jg7lNh3rZf6Qv7FJERM6InA79pfU6X19EcktOh/6iaWVMikUU+iKSM3I69KORPC6apZuli0juyOnQh+S4/uv7OznSOxh2KSIi406hX1eBO7y8W719Ecl+OR/6F86qIJpnGtcXkZyQ86FfFIvwrhmTFfoikhNyPvQhOcTzatMR+gaHwi5FRGRcKfRJHswdGErw+5YjYZciIjKu0gp9M1thZlvMbJuZ3X6cdT5pZpvMbKOZ3ZfSPmRm64PHmkwVnkkNwU1VXtypIR4RyW5j3hjdzCLAauBqoBlYZ2Zr3H1TyjrzgTuAZe5+2MympPyIXndfnOG6M6qyOMa8KSUa1xeRrJdOT38psM3dd7j7APAAcP2IdT4HrHb3wwDufjCzZY6/JXWVvLTrsMb1RSSrpRP6M4CmlOXmoC3VAmCBmT1jZs+b2YqU9wrNrDFo/8hoH2BmtwbrNLa2tp7UHyBTrrtgOl39cf7ldztD+XwRkTMhUwdyo8B84L3ADcD3zKw8eG+2uzcANwLfMrO5Izd297vdvcHdG2pqajJU0sm5dG4VyxdNZfXabRzo1KybIpKd0gn9FmBmynJt0JaqGVjj7oPuvhPYSvJLAHdvCZ53AE8BF55mzePmbz94LvEh5x/+8/WwSxERGRfphP46YL6Z1ZtZDFgJjDwL5+cke/mYWTXJ4Z4dZlZhZgUp7cuATUxQs6uK+czl9fzbyy28skfTMohI9hkz9N09DqwCHgc2Aw+5+0Yzu8vMrgtWexxoN7NNwFrgr929HTgXaDSzV4P2r6ee9TMRrbpyHjWlBXz1kU0kEh52OSIiGWXuEyvYGhoavLGxMdQaftrYxF8/vIFv/uEFfPTC2lBrERFJh5m9FBw/PSFdkTuKj11Uy/m1k/n6L1+nuz8edjkiIhmj0B9FXp7xlQ8v4kBnP9/9zfawyxERyRiF/nFcPLuS6xdP55+f3kHToZ6wyxERyQiF/gncfs1CImb8r19uDrsUEZGMUOifwLTJRfzpe+by2O/38/yO9rDLERE5bQr9Mdz67jnMKC/iq49sYkincIrIWU6hP4aiWIQ7rl3I5n2dPLiuaewNREQmMIV+Gj543jSW1FXw7f/aSnwoEXY5IiKnTKGfBjPjs1fM4UBnP2u3hDMLqIhIJij003TlwilMKS3g/hf3hF2KiMgpU+inKT+SxycbZvLUloO0dPSGXY6IyClR6J+EP1wyEwce0gFdETlLKfRPwszKSVwxv4aHGpt0QFdEzkoK/ZN049KZ7DvSx2+26oCuiJx9FPon6apzp1JdogO6InJ2UuifpOQB3Vp+/fpB9h3RAV0RObso9E/ByiWzSDg8tK457FJERE6KQv8UzKqaxBXzq3lw3R7NxyMiZ5W0Qt/MVpjZFjPbZma3H2edT5rZJjPbaGb3pbTfZGZvBI+bMlV42G5YOou9R/p4Wgd0ReQsEh1rBTOLAKuBq4FmYJ2ZrUm9wbmZzQfuAJa5+2EzmxK0VwJfARoAB14Ktj2c+T/KmfX+c6dSXRLjvhf38L6FU8IuR0QkLen09JcC29x9h7sPAA8A149Y53PA6uEwd/eDQfsHgCfc/VDw3hPAisyUHq5YNI+PXzyTX79+kP1H+sIuR0QkLemE/gwg9RLU5qAt1QJggZk9Y2bPm9mKk9j2rLVyyUyGEs5PG3WFroicHTJ1IDcKzAfeC9wAfM/MytPd2MxuNbNGM2tsbT17xsjrqotZNq+KB9Y16YCuiJwV0gn9
FmBmynJt0JaqGVjj7oPuvhPYSvJLIJ1tcfe73b3B3RtqampOpv7Q3bB0Fi0dvfz2jbPny0pEclc6ob8OmG9m9WYWA1YCa0as83OSvXzMrJrkcM8O4HFguZlVmFkFsDxoyxrLF51DVXFMV+iKyFlhzNB39ziwimRYbwYecveNZnaXmV0XrPY40G5mm4C1wF+7e7u7HwK+RvKLYx1wV9CWNZIHdGt5cvNBDnbqgK6ITGzmPrHGohsaGryxsTHsMk7KjtajXPmN3/BXyxew6sr5YZcjIjnIzF5y94ax1tMVuRkwp6aEZfOq+MkLexjUlMsiMoEp9DPk5svq2Xekj8c37g+7FBGR41LoZ8iVC6cwu2oS9/xuZ9iliIgcl0I/QyJ5xk2X1vHyng5ebeoIuxwRkVEp9DPoEw21lBRE+cEz6u2LyMSk0M+g0sJ8PtFQy6Mb9nFAp2+KyASk0M+wmy+rY8idHz+/O+xSRETeRqGfYbOrirlq4VR+8sIe+gaHwi5HROQtFPrj4DPL6jjUPcCa9XvDLkVE5C0U+uPg0rlVLDynlHue2clEu+JZRHKbQn8cmBmfXlbH6/u7eG5He9jliIgco9AfJ9cvnkFlcYwfPLMr7FJERI5R6I+TwvwINy6dxZObD7CnvSfsckREAIX+uPrjS2cTMePeZ3eFXYqICKDQH1dTywr54PnT+GljE0f742GXIyKi0B9vn15WT1d/nId183QRmQAU+uNs8cxyLppVzr3P7iKhm6eLSMgU+mfAp5fVs6u9h19tOhB2KSKS4xT6Z8CKd53D/CklfPWRjXT2DYZdjojksLRC38xWmNkWM9tmZreP8v7NZtZqZuuDx2dT3htKaV+TyeLPFvmRPP7xExdwsKufrz2yKexyRCSHRcdawcwiwGrgaqAZWGdma9x9ZHo96O6rRvkRve6++PRLPbtdMLOcP3vPXP7v2m1cc945XLlwatgliUgOSqenvxTY5u473H0AeAC4fnzLyk63XTWPheeUcvvPfs+RHg3ziMiZl07ozwBSzzdsDtpG+piZbTCzh81sZkp7oZk1mtnzZvaR0T7AzG4N1mlsbW1Nv/qzTEE0wj9+4gIOdQ9w5yMbwy5HRHJQpg7kPgLUufv5wBPAv6a8N9vdG4AbgW+Z2dyRG7v73e7e4O4NNTU1GSppYnrXjMn8+fvm8e+vtPD4xv1hlyMiOSad0G8BUnvutUHbMe7e7u79weL3gYtT3msJnncATwEXnka9WWHVlfNYNK2Mv/3333OoeyDsckQkh6QT+uuA+WZWb2YxYCXwlrNwzGxayuJ1wOagvcLMCoLX1cAyIOdPX8mP5PGNT17Akd5BvvyL18IuR0RyyJih7+5xYBXwOMkwf8jdN5rZXWZ2XbDa581so5m9CnweuDloPxdoDNrXAl8f5ayfnHTutDK+cNV8Ht2wj//YsC/sckQkR9hEu7NTQ0ODNzY2hl3GGREfSvDR//csLR29/OqL76a6pCDskkTkLGVmLwXHT09IV+SGKBoM8xzti/M/fv6abq0oIuNOoR+yBVNL+dLyBfzytf381U830K0pmEVkHI15Ra6Mv1uvmENPf5x/WruNV/Yc5p9uvJB3Tp8cdlkikoXU058A8vKMLy1/B/d99hK6B+J8dPWz/OCZnRruEZGMU+hPIJfOreKXX3g3V8yv5quPbOJzP2zUefwiklEK/QmmsjjG929q4CsfXsTTW9u45ttP89z29rDLEpEsoTH9CcjM+PSyepbUVfL5+1/hxu8/z8olsygritLTP0T3QPzYc3d/nJ6BIWorivibFQuZP7U07PJFZALTefoTXHd/nDvXbOThl5uJRfIoLogyKRahOBZlUkHyuSgW4YUd7fQMDHHzZXV84f3zKS3MH/fa+uND/O6NNna2dXPNedOYUV407p8pIqNL9zx9hf5ZIpFw8vLsuO+3H+3nH3+1hQfWNVFVXMAd1yzkoxfOOOE2p2IgnuCZ7W08+uo+frVp
P119yVNM8wyuOncqf3zJbC6fV53xz51odrd38+C6JhZMLeXDF0wncgb/vImEs765g4F4gvNmTKa44Mz+wt50qIeDXX1cNKsCszP79zyUcA529TFtsjoYIyn0c9SrTR18Zc1G1jd1cPHsCr563Tt514zTO/1zcCjBc9vbeXTDXh7feIAjvYOUFkZZvugcPnT+NOqqi/lpYxMPrmuivXuAuqpJfOqS2Xz84lrKJ8VO+LOHEo7BSX1JtB/t59nt7Ty7vY3XWjpZMLWUy+ZWcdm8qnEPg20Hu1i9dju/WN/C8H3u59YU84X3L+BD500bty87d+e1lk4e2bCXR1/dy94jfUDyy3bB1FIWzyzngpnlLJ5ZzoKppRn/EnJ3ntvezg+e3cWTmw/gDu+YWsotV9Rz/eLpFEQjGf28kQaHEvz8lRa+89R2drR1854FNdx+zULOnVY2rp87rG9wiPtf3MO/vdzCxbMr+PSyOmZXFZ+Rz06XQj+HJRLOwy838w+/fJ1DPQPcuHQW110wnaGEE094ynOCwSEnnkhwtC9OR88gHb2DdPQMcqR3kCO9A3T0DLL/SB9d/XFKCqIsXzSVD54/jcvnV7/tP3p/fIj/fG0/P3puN427D1MQzeO6C6Zz+fxq2o8O0Hq0n9aufg52JZ9bu/po7x6gMBph7pRi5taUMK+mhHlTko/ZVcXEonl098d5cdchnnmjjWe2t7N5XycApQVRFk0vY8uBLjqCm9LUVxdzyZwqLptbxSVzqqgpzczUFhv3HmH12m388rX9FEYjfOqSWdxy+Rxe3nOYbz25la0HjrJgaglfuGoB17zrnIyF/9YDXTzy6l4eeXUvu9p7iOYZ715Qw4cvmEZ5UYz1TR3HHkd6k/tgUizCeTMmM7NyEhWT8imfFKN8Uj4VwXN5UYzK4hhTSgvGrLN3YIifr2/h3md2seVAF5XFMW5cOotZlZO455mdvL6/i5rSAm6+rI4/+oNZY37Jn6z++BA/bWzmu7/ZTvPhXhZNK+O976jhJy/sobNvkI9dVMtfLl8wbl/23f1xfvLCbu5+eidtR/tZeE4p21uPEk84yxdN5ZbL57Ck7sz/xjMahb5wpHeQbz25lR8+t5uhRHp/z8WxCOWTYkwuymdyUT7lk/KpKonxngVTuGJ+NYX56fXoNu3t5Mcv7Obnr7TQMzAEQDTPqCktYEppATWlBdSUFlJTEqOrP862g0fZfvDosR7s8PrTy4vY29FLPOHEInlcPLuCZfOquGxeNefPmEw0kkci4Wze38lz29t5bns7L+w8xNHgyuY51cVMLy9iSmkBU8oKmVpWwJTS5PPUskLKJ+UTi+YRi+SN+h/3lT2HWb12G09uPkhpQZSbLqvjM5fXU1n8ZrglEs5//H4f33pyK9tbu1l4Til/8f4FfOCdU4/9zJ6BOHs7etnb0Rc899J6tJ+hhJNwSLjjwXPCkz3rNw4cZcuBLvIseTrvh8+fzop3nTNqsLo7u9p7WN90mPV7Oni1+QgHOvvo6Bmkd3Bo1L+jWDSP2ZWTqKsupr66mLqqYuqqJ1FXVYwDP35+N/e/uIeOnkHOnVbGp5fVcd0F04/9G3B3fretje/9didPb22lKD/CJxtq+czl9cd6wYNDiaATMcDhnkEOdw8EvynmU1tRxMyKSZQVRd+273sHhrjvxT3c/fR2DnT2s3hmOZ+/ah7ve8cUzIwjPYOsfmob9z6zCzO45fJ6/vS9cynL0LGsrr5Bfvjcbr7/2x0c7hnk8nnV3HblPP5gThUHO/v44XO7+fELu+noGeT82snccnk91543jfxIeCdEKvTlmN3t3TQf7iWSZ0TzjEiekR/Je8tyaWEy5GPRzP6j7eobZG9HHzWlBZQX5Y/Zs+zuj7OjtZttrV1sO3iUXW09zKycxLJ5VTTMrqQoNvaXTnwowWt7O3l2exsbmo5woKuPg539HOzqY3Do+P/e8yNGLJJHfvAlEMkz9h3po3xSPrcsq+dP
LqtjctHxQ2Uo4Ty6YS/ffvINdrR1M29KCbFIHnuP9B77TWRYnkFlcQH5ESMvCLy8PMiz5LIZVJcU8MHzpnHNeecwpbRwzD/38fQNDtHRM8jhnuRvbh09A7R1D9B0qIedbd3sautm96EeBuKJt9W4fNE53Lysjj+orzxhb/b1/Z18/7c7+cX6FuIJZ/rkIo70Dh778j2RkoIotRVFwWMSBfl5PNzYTHv3AJfMqeS2K+dz2dyqUT+/6VAP3/jVFn6+fi8Vk/L5/FXz+aM/mH3S/44TCWdgKEFn7yD3vbiHe363k86+OO97Rw2rrpzPxbMr3rZN78AQP3u5mXue2cmO1m6mTS5k5ZJZzKgooqQgQklBPsUFEUoKohQXRCkpjBLNMzp748Fv0m8+OoPnqpIYf3Jp3UnVPkyhLzJCIuF09A5ysKuPA539HAx6wgNDCQbiCQaD54Gh5Ov+eIJF08q4YemskzpYGh9K8Iv1e3mwsYmSgijTywuZXl7EjPIipgePKaUFofYKR0oknL1HetnV1sOu9m66+uJ86PxpzKycdFI/50BnHz9+fjfNh3vfOqQ0KUZ5UXK5rChKZ2+clo4emg/3pjx6aDncS1d/nPcsqGHVlfNYUleZ1ue+1nKE//nYZp7d3k5RfoTC/OSXdiTPiJgRiSSf8/IMHPrjCfrjQ/QPJv+eB4be+oV39aKp3HblPM6vLU9r3z219SDf/+1Onj3Na2qW1lfy0H+79JS2VeiLyFmpb3Ao7WHEVO7Ob7a28tSW1mPHrRIJZ8iTz/HgtQGF+REKonkURCMU5Oe9+Tqax6Vzq075AHFHzwCdvXGO9sfpHohztC943Z98HhzyY0OnZUXRY68nF+VTWph/Wgfg0w19XZwlIhPKqQQ+JC9qfO87pvDed0zJcEXpSx40z+zB7EybOL9fiojIuFPoi4jkkLRC38xWmNkWM9tmZreP8v7NZtZqZuuDx2dT3rvJzN4IHjdlsngRETk5Y47pm1kEWA1cDTQD68xszSg3OH/Q3VeN2LYS+ArQADjwUrDt4YxULyIiJyWdnv5SYJu773D3AeAB4Po0f/4HgCfc/VAQ9E8AK06tVBEROV3phP4MoClluTloG+ljZrbBzB42s5knua2IiJzKd0vbAAAE/ElEQVQBmTqQ+whQ5+7nk+zN/+vJbGxmt5pZo5k1tra2ZqgkEREZKZ3QbwFmpizXBm3HuHu7u/cHi98HLk5322D7u929wd0bampq0q1dRERO0phX5JpZFNgKXEUysNcBN7r7xpR1prn7vuD1R4G/cfdLggO5LwEXBau+DFzs7odO8HmtwO5T/yNRDbSdxvbjSbWdGtV2alTbqTlba5vt7mP2msc8e8fd42a2CngciAD3uPtGM7sLaHT3NcDnzew6IA4cAm4Otj1kZl8j+UUBcNeJAj/Y5rS6+mbWmM6lyGFQbadGtZ0a1XZqsr22tKZhcPfHgMdGtH055fUdwB3H2fYe4J7TqFFERDJEV+SKiOSQbAz9u8Mu4ARU26lRbadGtZ2arK5twk2tLCIi4ycbe/oiInIcWRP6Y00KFyYz22Vmvw8mowv9DjFmdo+ZHTSz11LaKs3siWBivCfM7O33hwunrjvNrCVlMr9rz3RdQR0zzWytmW0ys41m9oWgfSLst+PVFvq+M7NCM3vRzF4Navtq0F5vZi8E/18fNLMzPgn9CWq718x2puy3xWe6tpQaI2b2ipk9Giyf/n5z97P+QfJU0u3AHCAGvAosCruulPp2AdVh15FSz7tJXjvxWkrb/wZuD17fDvzDBKnrTuCvJsA+mwZcFLwuJXntyqIJst+OV1vo+w4woCR4nQ+8AFwCPASsDNq/C/zZBKrtXuDjYf+bC+r6EnAf8GiwfNr7LVt6+qczKVzOcfenSV5Pkep63pw+41+Bj5zRojhuXROCu+9z95eD113AZpLzSE2E/Xa82kLnSUeDxfzg4cCVwMNBe1j77Xi1TQhmVgt8kOQs
B1jyzvCnvd+yJfQn+sRuDvzKzF4ys1vDLuY4pnpwVTWwH5gaZjEjrAom87snjOGTkcysDriQZM9wQu23EbXBBNh3wRDFeuAgybm5tgMd7h4PVgnt/+vI2tx9eL/9fbDfvmlmBWHUBnwL+O/A8F3bq8jAfsuW0J/oLnf3i4BrgD83s3eHXdCJePJ3x4nS4/kOMBdYDOwDvhFmMWZWAvwM+At370x9L+z9NkptE2LfufuQuy8mOffWUmBhGHWMZmRtZvYukheaLgSWAJXA35zpuszsQ8BBd38p0z87W0I/rYndwuLuLcHzQeDfSf7Dn2gOmNk0SM6lRLLnEzp3PxD8x0wA3yPEfWdm+SRD9Sfu/m9B84TYb6PVNpH2XVBPB7AWuBQoD+b1ggnw/zWlthXBcJl7chLJHxDOflsGXGdmu0gOV18JfJsM7LdsCf11wPzgyHYMWAmsCbkmAMys2MxKh18Dy4HXTrxVKNYAw7ezvAn4RYi1HDMcqIGPEtK+C8ZT/wXY7O7/J+Wt0Pfb8WqbCPvOzGrMrDx4XUTyDnybSQbsx4PVwtpvo9X2esqXuJEcMz/j+83d73D3WnevI5lnv3b3PyIT+y3so9MZPMp9LcmzFrYDfxt2PSl1zSF5NtGrwMaJUBtwP8lf9wdJjgveQnK88L+AN4AngcoJUtePgN8DG0gG7LSQ9tnlJIduNgDrg8e1E2S/Ha+20PcdcD7wSlDDa8CXg/Y5wIvANuCnQMEEqu3XwX57DfgxwRk+YT2A9/Lm2Tunvd90Ra6ISA7JluEdERFJg0JfRCSHKPRFRHKIQl9EJIco9EVEcohCX0Qkhyj0RURyiEJfRCSH/H/r0BH34tWF7gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fecfa4b9be0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def train_network(num_epochs, num_steps, state_size=4, verbose=True):\n",
    "    \"\"\"Train the RNN graph built above and return the logged average losses.\n",
    "\n",
    "    Each epoch draws fresh data through gen_epochs. The hidden state starts at\n",
    "    zeros at the beginning of an epoch and is carried from one window to the next\n",
    "    by feeding the fetched final_state back in as init_state.\n",
    "\n",
    "    An average is logged whenever step is a positive multiple of 100. It is the\n",
    "    sum of the losses accumulated since the previous log divided by the number of\n",
    "    steps actually accumulated. The first window of an epoch spans steps 0..100,\n",
    "    i.e. 101 steps, so dividing by a fixed 100 would overstate it by about 1%.\n",
    "\n",
    "    Args:\n",
    "        num_epochs: number of epochs (fresh data sets) to train on.\n",
    "        num_steps: window length passed to gen_epochs; must match the graph.\n",
    "        state_size: width of the zero state fed at the start of each epoch; must\n",
    "            match the state_size the graph was built with.\n",
    "        verbose: when True, print the epoch headers and the logged averages.\n",
    "\n",
    "    Returns:\n",
    "        List of the logged average losses, in training order.\n",
    "    \"\"\"\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        training_losses = []\n",
    "        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n",
    "            training_loss = 0\n",
    "            window_steps = 0  # losses accumulated since the last log\n",
    "            training_state = np.zeros((batch_size, state_size))\n",
    "            if verbose:\n",
    "                print(\"\\nEPOCH\", idx)\n",
    "            for step, (X, Y) in enumerate(epoch):\n",
    "                # The per-step `losses` tensors are not fetched: their values were never used.\n",
    "                training_loss_, training_state, _ = sess.run(\n",
    "                    [total_loss, final_state, train_step],\n",
    "                    feed_dict={x:X, y:Y, init_state:training_state})\n",
    "                training_loss += training_loss_\n",
    "                window_steps += 1\n",
    "                if step % 100 == 0 and step > 0:\n",
    "                    average_loss = training_loss / window_steps\n",
    "                    if verbose:\n",
    "                        print(\"Average loss at step\", step,\n",
    "                              \"for last\", window_steps, \"steps:\", average_loss)\n",
    "                    training_losses.append(average_loss)\n",
    "                    training_loss = 0\n",
    "                    window_steps = 0\n",
    "    return training_losses\n",
    "# Train for 10 epochs and plot one point per logged average (every 100 steps).\n",
    "training_losses = train_network(10, num_steps, state_size=state_size)\n",
    "plt.plot(training_losses)\n",
    "plt.xlabel('logging window (every 100 training steps)')\n",
    "plt.ylabel('average training loss (cross-entropy)')\n",
    "plt.title('RNN training loss');  # trailing ';' hides the matplotlib object repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
